import argparse
import numpy as np
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import torch
import copy
import json
import os
from tqdm import tqdm
from decord import VideoReader, cpu
import warnings

# Suppress every Python warning for the rest of the run (this also hides
# potentially useful deprecation/runtime warnings while debugging).
warnings.filterwarnings("ignore")

def process_video(video_path, max_frames_num, fps=1, force_sample=False, keyframe_indices=None):
    """Decode frames from a video and report when each decoded frame occurs.

    Args:
        video_path: Path to a video file readable by decord's ``VideoReader``.
        max_frames_num: Maximum number of uniformly sampled frames. ``0`` short-circuits:
            the file is not opened and a single all-zero 336x336x3 array is returned.
        fps: Target sampling rate (frames per second) for the default stride-based sampling.
        force_sample: If True, always sample exactly ``max_frames_num`` frames uniformly
            across the whole video instead of using the ``fps`` stride.
        keyframe_indices: Optional explicit frame indices to decode. When given (and
            ``max_frames_num != 0``), ``fps``, ``force_sample`` and the frame cap are ignored.

    Returns:
        ``(frames, frame_time, video_time)``:
        ``frames`` is decord's ``get_batch(...).asnumpy()`` result (presumably
        ``(num_frames, H, W, 3)`` -- TODO confirm against decord);
        ``frame_time`` is a comma-separated string such as ``"0.00s,1.00s"``;
        ``video_time`` is the video duration in seconds (``0`` on the short-circuit path).
    """
    if max_frames_num == 0:
        return np.zeros((1, 336, 336, 3)), "", 0

    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    avg_fps = vr.get_avg_fps()
    video_time = total_frame_num / avg_fps

    def _format_times(indices):
        # Convert frame indices to seconds with the video's native frame rate, not the
        # sampling stride, so timestamps are correct for any requested `fps`.
        return ",".join(f"{i / avg_fps:.2f}s" for i in indices)

    if keyframe_indices is not None:
        frame_idx = keyframe_indices
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        return spare_frames, _format_times(frame_idx), video_time

    # Number of native frames between samples; clamp to 1 so a very low native frame
    # rate (avg_fps < fps / 2) cannot produce a zero step in range().
    step = max(1, round(avg_fps / fps))
    frame_idx = list(range(0, total_frame_num, step))
    if len(frame_idx) > max_frames_num or force_sample:
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    return spare_frames, _format_times(frame_idx), video_time

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so the value must be parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Parse CLI arguments before loading the model so `--help` and argument
# errors return immediately instead of after a multi-GB model load.
parser = argparse.ArgumentParser()
parser.add_argument("--max_frames", type=int, default=16)
parser.add_argument(
    "--add_time_instruction",
    type=_str2bool,
    nargs="?",
    const=True,  # a bare `--add_time_instruction` flag means True
    default=False,
)
parser.add_argument('--dataset_path', type=str, default='datasets/videomme/data')  # replace your dataset path here
args = parser.parse_args()
max_frames_num = args.max_frames
conv_template = "chatml_direct"
data_path = args.dataset_path

# Model configuration: LLaVA-Video 7B (Qwen2) with the SigLIP vision tower overridden.
pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
model_name = "llava_qwen"
device = "cuda"
device_map = "auto"
overwrite_config = {'mm_vision_tower': "google/siglip-so400m-patch14-384"}
tokenizer, model, image_processor, max_length = load_pretrained_model(
    pretrained, None, model_name, torch_dtype="bfloat16",
    device_map=device_map, overwrite_config=overwrite_config
)
model.eval()

def llava_inference(qs, video, video_time, frame_time):
    """Run one greedy generation step for a question about a preprocessed video.

    Relies on module-level ``args``, ``conv_template``, ``tokenizer``, ``model``
    and ``device``.

    Args:
        qs: Full question text (question, options and answer instruction).
        video: List of video tensors passed to ``model.generate`` as ``images``;
            ``video[0].shape[0]`` is reported as the frame count in the time instruction.
        video_time: Video duration in seconds (used only with ``--add_time_instruction``).
        frame_time: Pre-formatted frame timestamp string (used only with
            ``--add_time_instruction``).

    Returns:
        The first decoded output sequence with special tokens removed and
        surrounding whitespace stripped.
    """
    if not args.add_time_instruction:
        question = DEFAULT_IMAGE_TOKEN * len(video) + "\n" + qs
    else:
        n_frames = video[0].shape[0]
        time_instruction = (
            f"The video lasts for {video_time:.2f} seconds, and {n_frames} frames are uniformly sampled from it. "
            f"These frames are located at {frame_time}. "
            "Please answer the following questions related to this video."
        )
        question = f"{DEFAULT_IMAGE_TOKEN}{time_instruction}\n{qs}"

    # Work on a private copy so the shared template is never mutated between calls.
    conversation = copy.deepcopy(conv_templates[conv_template])
    user_role, assistant_role = conversation.roles[0], conversation.roles[1]
    conversation.append_message(user_role, question)
    conversation.append_message(assistant_role, None)  # empty slot the model completes
    prompt = conversation.get_prompt()

    token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    input_ids = token_ids.unsqueeze(0).to(device)

    generation_kwargs = dict(
        images=video,
        modalities=["video"],
        do_sample=False,
        temperature=0,
        max_new_tokens=16,
    )
    output_ids = model.generate(input_ids, **generation_kwargs)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()

# Resumable evaluation loop: one output record per video item, checkpointed to
# `json_file` after every item.
os.makedirs("results", exist_ok=True)
json_file = f"results/efs_videomme_{max_frames_num}frames.json"

# NOTE(review): the keyframe file is hard-coded to the 16-frame EFS selection, while the
# output name uses --max_frames. process_video ignores max_frames_num whenever
# keyframe_indices is given (other than 0), so a non-16 --max_frames only changes the
# output filename. Confirm this is intended.
with open("./videomme_16frames_selected_by_efs.json", 'r', encoding='utf-8') as file:
    mme_data = json.load(file)

# Resume support: records already in the results file are assumed to correspond to
# the first len(rep_list) items of mme_data, in order.
if os.path.exists(json_file):
    with open(json_file, "r", encoding='utf-8') as file:
        rep_list = json.load(file)
else:
    rep_list = []

index = len(rep_list)

for item in tqdm(mme_data[index:], desc="Processing items"):
    video_path = os.path.join(data_path, item['url'] + ".mp4")
    # Deep copy so responses are written into the output record, not into mme_data's
    # nested question dicts (a shallow .copy() would share them).
    content = copy.deepcopy(item)
    for question in content['questions']:
        keyframe_indices = question["keyframe_indices"]
        video, frame_time, video_time = process_video(
            video_path, max_frames_num, 1, force_sample=True, keyframe_indices=keyframe_indices
        )
        video_tensor = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
        video_tensor = [video_tensor]
        qs = (
            "Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option."
            + '\n' + question["question"] + '\n'
            + "\n".join(question["options"]) + '\n' + "\nAnswer with the option's letter from the given choices directly."
        )
        res = llava_inference(qs, video_tensor, video_time, frame_time)
        question["response"] = res

    rep_list.append(content)

    # Write to a temporary file and atomically swap it in, so an interruption mid-write
    # cannot leave a truncated JSON file that would make the resume step above fail.
    tmp_file = json_file + ".tmp"
    with open(tmp_file, "w", encoding='utf-8') as file:
        json.dump(rep_list, file, ensure_ascii=False, indent=4)
    os.replace(tmp_file, json_file)